{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Transfer Learning with skorch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial, you will learn how to train a neural network with transfer learning using the `skorch` API. Transfer learning initializes a network from a model that was pretrained on another task. This tutorial ports the pure PyTorch approach described in [PyTorch's Transfer Learning Tutorial](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html) to `skorch`.\n",
    "\n",
    "We will be using `torchvision` for this tutorial. Instructions on how to install `torchvision` for your platform can be found at https://pytorch.org.\n",
    "\n",
    "<table align=\"left\"><td>\n",
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/skorch-dev/skorch/blob/master/notebooks/Transfer_Learning.ipynb\">\n",
    "    <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>  \n",
    "</td><td>\n",
    "<a target=\"_blank\" href=\"https://github.com/skorch-dev/skorch/blob/master/notebooks/Transfer_Learning.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a></td></table>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Note**: If you are running this in [a Colab notebook](https://colab.research.google.com/github/skorch-dev/skorch/blob/master/notebooks/Transfer_Learning.ipynb), we recommend enabling a free GPU via:\n",
    "\n",
    "> **Runtime**   →   **Change runtime type**   →   **Hardware Accelerator: GPU**\n",
    "\n",
    "If you are running in Colab, install the dependencies and download the dataset by running the following cell:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "! [ ! -z \"$COLAB_GPU\" ] && pip install torch torchvision pillow==4.1.1 skorch\n",
    "! [ ! -z \"$COLAB_GPU\" ] && mkdir -p datasets\n",
    "! [ ! -z \"$COLAB_GPU\" ] && wget -nc --no-check-certificate https://download.pytorch.org/tutorial/hymenoptera_data.zip -P datasets\n",
    "! [ ! -z \"$COLAB_GPU\" ] && unzip -u datasets/hymenoptera_data.zip -d datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from urllib import request\n",
    "from zipfile import ZipFile\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import numpy as np\n",
    "from torchvision import datasets, models, transforms\n",
    "\n",
    "from skorch import NeuralNetClassifier\n",
    "from skorch.helper import predefined_split\n",
    "\n",
    "torch.manual_seed(360);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preparations"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before we begin, let's download the data needed for this tutorial:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data has been downloaded and extracted to datasets.\n"
     ]
    }
   ],
   "source": [
    "def download_and_extract_data(dataset_dir='datasets'):\n",
    "    data_zip = os.path.join(dataset_dir, 'hymenoptera_data.zip')\n",
    "    data_path = os.path.join(dataset_dir, 'hymenoptera_data')\n",
    "    url = \"https://download.pytorch.org/tutorial/hymenoptera_data.zip\"\n",
    "\n",
    "    if not os.path.exists(data_path):\n",
    "        if not os.path.exists(data_zip):\n",
    "            print(\"Starting to download data...\")\n",
    "            data = request.urlopen(url, timeout=15).read()\n",
    "            with open(data_zip, 'wb') as f:\n",
    "                f.write(data)\n",
    "\n",
    "        print(\"Starting to extract data...\")\n",
    "        with ZipFile(data_zip, 'r') as zip_f:\n",
    "            zip_f.extractall(dataset_dir)\n",
    "        \n",
    "    print(\"Data has been downloaded and extracted to {}.\".format(dataset_dir))\n",
    "    \n",
    "download_and_extract_data()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The Problem"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We are going to train a neural network to classify **ants** and **bees**. The dataset consists of about 120 training images and 75 validation images for each class. First, we create the training and validation datasets:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = 'datasets/hymenoptera_data'\n",
    "train_transforms = transforms.Compose([\n",
    "    transforms.RandomResizedCrop(224),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize([0.485, 0.456, 0.406], \n",
    "                         [0.229, 0.224, 0.225])\n",
    "])\n",
    "val_transforms = transforms.Compose([\n",
    "    transforms.Resize(256),\n",
    "    transforms.CenterCrop(224),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize([0.485, 0.456, 0.406], \n",
    "                         [0.229, 0.224, 0.225])\n",
    "])\n",
    "\n",
    "train_ds = datasets.ImageFolder(\n",
    "    os.path.join(data_dir, 'train'), train_transforms)\n",
    "val_ds = datasets.ImageFolder(\n",
    "    os.path.join(data_dir, 'val'), val_transforms)"
   ]
  },
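  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`datasets.ImageFolder` derives the labels from the subdirectory names under `train` and `val`: it sorts the class folders by name and numbers them from 0, so `ants` should become class 0 and `bees` class 1 (`train_ds.class_to_idx` shows the actual mapping). The cell below is a small stand-alone sketch of that convention; it builds a throwaway directory tree with a hypothetical file name (`example.jpg`) and does not touch the real dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tempfile\n",
    "\n",
    "def class_to_index(root):\n",
    "    # Sketch of the ImageFolder convention: one subdirectory per class, sorted by name\n",
    "    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())\n",
    "    return {name: idx for idx, name in enumerate(classes)}\n",
    "\n",
    "with tempfile.TemporaryDirectory() as root:\n",
    "    # Hypothetical layout: root/<class name>/<image file>\n",
    "    for class_name in ['bees', 'ants']:\n",
    "        os.makedirs(os.path.join(root, class_name))\n",
    "        open(os.path.join(root, class_name, 'example.jpg'), 'wb').close()\n",
    "    mapping = class_to_index(root)\n",
    "\n",
    "print(mapping)  # {'ants': 0, 'bees': 1}"
   ]
  },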
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The training dataset uses data augmentation: random resized crops of size 224×224 and random horizontal flips. The validation images are instead resized so that their shorter side is 256 pixels and then center-cropped to 224×224. Both datasets are normalized with mean `[0.485, 0.456, 0.406]` and standard deviation `[0.229, 0.224, 0.225]`, the per-channel means and standard deviations of the ImageNet images. We use these values because the pretrained model was trained on ImageNet."
   ]
  },
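  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Concretely, `transforms.ToTensor` scales pixel values to the range [0, 1], and `transforms.Normalize` then subtracts the channel mean and divides by the channel standard deviation, separately for the red, green, and blue channels. Here is a minimal pure-Python sketch of that per-channel arithmetic for a single pixel (the helper `normalize_pixel` is ours, not part of `torchvision`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ImageNet channel statistics, as passed to transforms.Normalize above\n",
    "IMAGENET_MEAN = [0.485, 0.456, 0.406]\n",
    "IMAGENET_STD = [0.229, 0.224, 0.225]\n",
    "\n",
    "def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):\n",
    "    # rgb holds channel values already scaled to [0, 1], as ToTensor produces\n",
    "    return [(value - m) / s for value, m, s in zip(rgb, mean, std)]\n",
    "\n",
    "# A pixel equal to the ImageNet mean maps to 0 in every channel\n",
    "print(normalize_pixel([0.485, 0.456, 0.406]))  # [0.0, 0.0, 0.0]\n",
    "\n",
    "# A pure white pixel ends up more than 2 standard deviations above the mean in each channel\n",
    "print(normalize_pixel([1.0, 1.0, 1.0]))"
   ]
  },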
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading pretrained model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use a pretrained `ResNet18` neural network and replace its final layer with a new fully connected layer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PretrainedModel(nn.Module):\n",
    "    def __init__(self, output_features):\n",
    "        super().__init__()\n",
    "        model = models.resnet18(pretrained=True)\n",
    "        num_ftrs = model.fc.in_features\n",
    "        model.fc = nn.Linear(num_ftrs, output_features)\n",
    "        self.model = model\n",
    "        \n",
    "    def forward(self, x):\n",
    "        return self.model(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since we are training a binary classifier, the output of the final fully connected layer has size 2."
   ]
  },
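  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, the cell below builds the same architecture with randomly initialized weights (`models.resnet18()` without pretrained weights, so nothing is downloaded), swaps in a 2-output head as `PretrainedModel` does, and confirms that a batch of 224×224 RGB inputs yields one score per class. The random `dummy_batch` is only a stand-in for real images, and `torch.random.fork_rng` keeps the seeded random state of the notebook unchanged."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torchvision import models\n",
    "\n",
    "with torch.random.fork_rng():  # leave the notebook's seeded RNG state untouched\n",
    "    # Same construction as PretrainedModel, but with random instead of pretrained weights\n",
    "    backbone = models.resnet18()\n",
    "    print(backbone.fc.in_features)  # 512: the features feeding the final layer\n",
    "    backbone.fc = nn.Linear(backbone.fc.in_features, 2)\n",
    "    backbone.eval()\n",
    "\n",
    "    # Two random 3x224x224 tensors stand in for normalized images\n",
    "    dummy_batch = torch.randn(2, 3, 224, 224)\n",
    "    with torch.no_grad():\n",
    "        logits = backbone(dummy_batch)\n",
    "\n",
    "print(logits.shape)  # torch.Size([2, 2]): two class scores for each of the two images"
   ]
  },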
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using skorch's API"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this section, we will create a `skorch.NeuralNetClassifier` to solve our classification problem. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Callbacks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we create an `LRScheduler` callback, a learning rate scheduler that uses `torch.optim.lr_scheduler.StepLR` to multiply the learning rate by `gamma=0.1` every 7 epochs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from skorch.callbacks import LRScheduler\n",
    "\n",
    "lrscheduler = LRScheduler(\n",
    "    policy='StepLR', step_size=7, gamma=0.1)"
   ]
  },
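  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see how this schedule affects the learning rate of 0.001 that we use below, the following sketch drives `torch.optim.lr_scheduler.StepLR` directly, outside of `skorch`, with a placeholder parameter (`dummy_param`) and an `SGD` optimizer. It steps the scheduler once per epoch, which we assume matches the default behavior of `skorch`'s `LRScheduler`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# A single placeholder parameter so the optimizer has something to manage\n",
    "dummy_param = torch.nn.Parameter(torch.zeros(1))\n",
    "sgd = torch.optim.SGD([dummy_param], lr=0.001, momentum=0.9)\n",
    "step_lr = torch.optim.lr_scheduler.StepLR(sgd, step_size=7, gamma=0.1)\n",
    "\n",
    "lr_per_epoch = []\n",
    "for epoch in range(25):\n",
    "    # Record the learning rate in effect during this epoch\n",
    "    lr_per_epoch.append(sgd.param_groups[0]['lr'])\n",
    "    sgd.step()      # no gradients here, so the parameter is left unchanged\n",
    "    step_lr.step()  # advance the schedule once per epoch\n",
    "\n",
    "# Epochs are printed 1-based, as in the skorch training log\n",
    "for epoch in (0, 6, 7, 14, 21):\n",
    "    print('epoch', epoch + 1, 'lr', lr_per_epoch[epoch])"
   ]
  },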
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we create a `Checkpoint` callback, which saves the best model by monitoring the validation accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from skorch.callbacks import Checkpoint\n",
    "\n",
    "checkpoint = Checkpoint(\n",
    "    f_params='best_model.pt', monitor='valid_acc_best')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Lastly, we create a `Freezer` callback that freezes all weights except those of the final layer, whose parameter names start with `model.fc`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from skorch.callbacks import Freezer\n",
    "\n",
    "freezer = Freezer(lambda x: not x.startswith('model.fc'))"
   ]
  },
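  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The predicate passed to `Freezer` is called with parameter names as produced by `named_parameters()` on our module, and the parameters for which it returns `True` are frozen. Because `PretrainedModel` stores the network in the attribute `self.model`, the names carry a `model.` prefix. The sketch below applies the same predicate to a randomly initialized copy of the architecture (`NamesOnly` is a hypothetical stand-in, so no weights are downloaded) and lists what remains trainable:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torchvision import models\n",
    "\n",
    "class NamesOnly(nn.Module):\n",
    "    # Hypothetical stand-in with the same attribute layout as PretrainedModel,\n",
    "    # built from randomly initialized weights so nothing is downloaded\n",
    "    def __init__(self, output_features=2):\n",
    "        super().__init__()\n",
    "        self.model = models.resnet18()\n",
    "        self.model.fc = nn.Linear(self.model.fc.in_features, output_features)\n",
    "\n",
    "def is_frozen(name):\n",
    "    # Same predicate as the one given to Freezer above\n",
    "    return not name.startswith('model.fc')\n",
    "\n",
    "with torch.random.fork_rng():  # leave the notebook's seeded RNG state untouched\n",
    "    sketch_module = NamesOnly()\n",
    "\n",
    "trainable = [name for name, _ in sketch_module.named_parameters() if not is_frozen(name)]\n",
    "n_trainable = sum(p.numel() for name, p in sketch_module.named_parameters() if not is_frozen(name))\n",
    "print(trainable)    # ['model.fc.weight', 'model.fc.bias']\n",
    "print(n_trainable)  # 1026 = 512 * 2 weights + 2 biases"
   ]
  },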
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### skorch.NeuralNetClassifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With all the preparations out of the way, we can now define our `NeuralNetClassifier`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = NeuralNetClassifier(\n",
    "    PretrainedModel, \n",
    "    criterion=nn.CrossEntropyLoss,\n",
    "    lr=0.001,\n",
    "    batch_size=4,\n",
    "    max_epochs=25,\n",
    "    module__output_features=2,\n",
    "    optimizer=optim.SGD,\n",
    "    optimizer__momentum=0.9,\n",
    "    iterator_train__shuffle=True,\n",
    "    iterator_train__num_workers=4,\n",
    "    iterator_valid__shuffle=True,\n",
    "    iterator_valid__num_workers=4,\n",
    "    train_split=predefined_split(val_ds),\n",
    "    callbacks=[lrscheduler, checkpoint, freezer],\n",
    "    device='cuda' if torch.cuda.is_available() else 'cpu'  # use the GPU when one is available\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "That is quite a few parameters! Let's walk through each one:\n",
    "\n",
    "1. `PretrainedModel`: Our `ResNet18`-based module, passed as a class so that `skorch` can instantiate it\n",
    "2. `criterion=nn.CrossEntropyLoss`: The loss function\n",
    "3. `lr`: The initial learning rate\n",
    "4. `batch_size`: The number of images per batch\n",
    "5. `max_epochs`: The number of epochs to train\n",
    "6. `module__output_features`: Passed to `__init__` of our `PretrainedModel` class to set the number of classes\n",
    "7. `optimizer`: Our optimizer, `optim.SGD`\n",
    "8. `optimizer__momentum`: The momentum passed to the optimizer\n",
    "9. `iterator_{train,valid}__{shuffle,num_workers}`: Parameters passed to the training and validation data loaders\n",
    "10. `train_split`: `predefined_split(val_ds)` makes `skorch` use `val_ds` as the validation set instead of holding out part of the training data\n",
    "11. `callbacks`: Our callbacks\n",
    "12. `device`: The device to train on; `'cuda'` trains on the GPU\n",
    "\n",
    "Now we are ready to train our neural network:"
   ]
  },
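  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The double-underscore names (`module__output_features`, `optimizer__momentum`, `iterator_train__shuffle`, ...) follow a prefix convention: the part before `__` names the component, and the rest is the keyword argument that component receives. The hypothetical helper `split_prefixed` below is only a simplified illustration of that routing, not `skorch`'s actual implementation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_prefixed(prefix, params):\n",
    "    # Hypothetical helper: collect 'prefix__name' entries as {name: value}\n",
    "    marker = prefix + '__'\n",
    "    return {key[len(marker):]: value\n",
    "            for key, value in params.items() if key.startswith(marker)}\n",
    "\n",
    "net_kwargs = {\n",
    "    'lr': 0.001,\n",
    "    'module__output_features': 2,\n",
    "    'optimizer__momentum': 0.9,\n",
    "    'iterator_train__shuffle': True,\n",
    "    'iterator_train__num_workers': 4,\n",
    "}\n",
    "\n",
    "print(split_prefixed('module', net_kwargs))          # {'output_features': 2}\n",
    "print(split_prefixed('optimizer', net_kwargs))       # {'momentum': 0.9}\n",
    "print(split_prefixed('iterator_train', net_kwargs))  # {'shuffle': True, 'num_workers': 4}"
   ]
  },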
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  epoch    train_loss    valid_acc    valid_loss    cp     dur\n",
      "-------  ------------  -----------  ------------  ----  ------\n",
      "      1        \u001b[36m0.8220\u001b[0m       \u001b[32m0.9150\u001b[0m        \u001b[35m0.2294\u001b[0m     +  1.7953\n",
      "      2        \u001b[36m0.4949\u001b[0m       \u001b[32m0.9346\u001b[0m        \u001b[35m0.2116\u001b[0m     +  0.9276\n",
      "      3        \u001b[36m0.4873\u001b[0m       0.8105        0.4593        0.9309\n",
      "      4        0.5291       \u001b[32m0.9477\u001b[0m        \u001b[35m0.1725\u001b[0m     +  0.9292\n",
      "      5        \u001b[36m0.4530\u001b[0m       0.9216        0.2275        0.9046\n",
      "      6        \u001b[36m0.3869\u001b[0m       0.9412        \u001b[35m0.1697\u001b[0m        0.9121\n",
      "      7        \u001b[36m0.2903\u001b[0m       \u001b[32m0.9608\u001b[0m        0.1778     +  0.9504\n",
      "      8        0.3000       0.9477        0.1769        0.9169\n",
      "      9        0.4068       0.9542        0.1830        0.9312\n",
      "     10        0.5076       0.9281        0.1953        1.0024\n",
      "     11        0.3271       0.9346        0.1911        0.9144\n",
      "     12        0.3728       0.9281        0.2180        0.8806\n",
      "     13        \u001b[36m0.2847\u001b[0m       0.9477        0.1847        0.9466\n",
      "     14        0.3526       0.9216        0.2333        0.9141\n",
      "     15        0.3254       0.9281        0.1802        0.8951\n",
      "     16        0.3407       0.9477        0.1888        0.8973\n",
      "     17        \u001b[36m0.2498\u001b[0m       0.9346        0.1931        0.9159\n",
      "     18        0.4421       0.9477        0.1848        0.9186\n",
      "     19        0.3548       0.9216        0.2010        0.8960\n",
      "     20        0.3037       0.9281        0.2188        0.9178\n",
      "     21        0.3454       0.9542        0.1837        0.9184\n",
      "     22        0.3227       0.9412        0.1732        0.9115\n",
      "     23        0.2595       0.9542        0.1765        0.9040\n",
      "     24        0.3164       0.9477        0.1794        0.9101\n",
      "     25        0.2607       0.9412        0.1934        0.9493\n"
     ]
    }
   ],
   "source": [
    "net.fit(train_ds, y=None);"
   ]
  },
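  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `cp` column marks the epochs in which `Checkpoint` wrote `best_model.pt`: with `monitor='valid_acc_best'`, a checkpoint is saved only when the validation accuracy is higher than in every earlier epoch. The pure-Python sketch below replays that rule on the `valid_acc` values copied from the log above and recovers the same epochs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# valid_acc for epochs 1-25, copied from the training log above\n",
    "valid_acc = [\n",
    "    0.9150, 0.9346, 0.8105, 0.9477, 0.9216, 0.9412, 0.9608, 0.9477, 0.9542,\n",
    "    0.9281, 0.9346, 0.9281, 0.9477, 0.9216, 0.9281, 0.9477, 0.9346, 0.9477,\n",
    "    0.9216, 0.9281, 0.9542, 0.9412, 0.9542, 0.9477, 0.9412,\n",
    "]\n",
    "\n",
    "best = float('-inf')\n",
    "checkpoint_epochs = []\n",
    "for epoch, acc in enumerate(valid_acc, start=1):\n",
    "    # valid_acc_best is True only when this epoch improves on all earlier ones\n",
    "    if acc > best:\n",
    "        best = acc\n",
    "        checkpoint_epochs.append(epoch)\n",
    "\n",
    "print(checkpoint_epochs)  # [1, 2, 4, 7]\n",
    "print(best)               # 0.9608"
   ]
  },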
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The best model is stored at `best_model.pt`, with a validation accuracy of roughly 0.96.\n",
    "\n",
    "Congratulations! You now know how to fine-tune a neural network using `skorch`. Feel free to explore the other tutorials to learn more about using `skorch`."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
